
import math
import time

import torch
from thop import profile
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import init

from layer import CS_loss, Reconstruction_loss, S_weights, T_weights


class GCNLayer(nn.Module):
    """Single graph-convolution layer computing ``adj @ (x @ weight)``.

    Args:
        input_dim: Size of the last dimension of ``x``.
        output_dim: Size of the last dimension of the output.
        use_bias: If True, ``self.bias`` is added to the output. Defaults to
            False, which reproduces the original behaviour: the bias
            parameter exists (and is saved in ``state_dict``) but is not
            applied in ``forward``.
    """

    def __init__(self, input_dim, output_dim, use_bias=False):
        super(GCNLayer, self).__init__()
        self.use_bias = use_bias
        self.weight = nn.Parameter(torch.empty(input_dim, output_dim))
        init.xavier_uniform_(self.weight, gain=math.sqrt(2.0))
        # torch.Tensor(n) / torch.empty(n) return uninitialised memory. Zero the
        # bias so the parameter never holds garbage/NaN values in checkpoints,
        # parameter norms or profiler output.
        self.bias = nn.Parameter(torch.empty(output_dim))
        init.zeros_(self.bias)

    def forward(self, x, adj):
        """Return ``adj @ (x @ weight)``, plus ``bias`` when ``use_bias`` is set.

        Both products use ``torch.matmul``, so leading batch dimensions
        broadcast. For example, x of shape (B, N, input_dim) with adj of shape
        (B, N, N) gives (B, N, output_dim).
        """
        support = torch.matmul(x, self.weight)
        output = torch.matmul(adj, support)
        if self.use_bias:
            output = output + self.bias
        return output



class GCN(nn.Module):
    """Two stacked GCNLayers with a ReLU in between.

    Computes ``gcn2(relu(gcn1(X, adj)), adj)``. Both layers use the same
    adjacency; the second layer maps hidden_dim -> hidden_dim.
    """

    def __init__(self, num_features, hidden_dim):
        super().__init__()
        self.gcn1 = GCNLayer(num_features, hidden_dim)
        self.gcn2 = GCNLayer(hidden_dim, hidden_dim)

    def forward(self, X, adj):
        hidden = F.relu(self.gcn1(X, adj))
        return self.gcn2(hidden, adj)


class TPGCN(nn.Module):
    """GCN layer -> single-channel 3x3 Conv2d -> GCN layer.

    The first layer's output gets a channel axis inserted at dim 1, so it is
    convolved as a one-channel image; the axis is then removed again. For a
    batched (B, N, H) embedding the convolution sees (B, 1, N, H).
    """

    def __init__(self, num_features, hidden_dim):
        super().__init__()
        self.gcn1 = GCNLayer(num_features, hidden_dim)
        # 1 in-channel, 1 out-channel, 3x3 kernel; padding=1 keeps the size.
        self.cnn1 = nn.Conv2d(1, 1, 3, padding=1)
        self.gcn2 = GCNLayer(hidden_dim, hidden_dim)

    def forward(self, X, adj):
        as_image = self.gcn1(X, adj).unsqueeze(1)
        convolved = self.cnn1(as_image).squeeze(1)
        return self.gcn2(convolved, adj)

# X = torch.randn(20,90,45)
# adj = torch.randn(20,90,90)
# model = TPGCN(45,16)
# out = model(X,adj)
# print(out.shape)

class STPD(nn.Module):
    """Spatial/temporal decomposition of a pair of windows.

    Two GCNs (``S1_GCN`` and ``T1_GCN``) are applied to each window, and the
    same two GCNs are reused for both windows. For each window the module
    returns the S output, the T output and their element-wise sum.

    ``decoder1`` / ``decoder2`` are constructed (so they appear in
    ``state_dict`` and parameter counts) but ``forward`` does not use them.
    Their 90-unit output size is hard-coded.
    """

    def __init__(self, num_features, hidden_dim):
        super().__init__()
        self.num_features = num_features
        self.hidden_dim = hidden_dim
        self.S1_GCN = GCN(num_features, hidden_dim)
        self.T1_GCN = GCN(num_features, hidden_dim)
        self.decoder1 = nn.Linear(hidden_dim, 90)
        self.decoder2 = nn.Linear(hidden_dim, 90)

    def forward(self, A_1, A_2, X_1, X_2):
        """Return (S1, T1, S1+T1, S2, T2, S2+T2) for (X_1, A_1) and (X_2, A_2)."""
        results = []
        for adj, feats in ((A_1, X_1), (A_2, X_2)):
            spatial = self.S1_GCN(feats, adj)
            temporal = self.T1_GCN(feats, adj)
            results.extend((spatial, temporal, spatial + temporal))
        return tuple(results)


# a = torch.randn(12,90,90)
# x = torch.randn(12,90,32)
#
# model = STPD(32,16)
# out,_,_,_,_,_ = model(a,a,x,x)
# print(out.shape)



class model_all(nn.Module):
    """Multi-window spatio-temporal GCN classifier.

    Each consecutive pair of windows goes through its own ``STPD`` module,
    which yields spatial (S) and temporal (T) embeddings per window.
    Auxiliary losses are accumulated from those embeddings. The per-window
    embeddings are then rescaled by ``S_weights`` / ``T_weights`` and passed
    through one ``TPGCN`` per window. Finally they are averaged over windows
    and classified into 2 classes (log-probabilities).

    Args:
        num_features: Feature size of each node in ``fea_all``.
        hidden_dim: Embedding size used by the GCN layers.
        num_wins: Number of windows. Must be at least 2, because windows are
            processed in consecutive pairs.
        num_nodes: Number of graph nodes per window. Defaults to 90, the
            value previously hard-coded in the classifier head.

    Raises:
        ValueError: If ``num_wins < 2``.
    """

    def __init__(self, num_features, hidden_dim, num_wins, num_nodes=90):
        super(model_all, self).__init__()
        if num_wins < 2:
            # With fewer than two windows there is no consecutive pair, and
            # forward() would fail on torch.stack of an empty list.
            raise ValueError(f"num_wins must be >= 2, got {num_wins}")
        self.num_features = num_features
        self.hidden_dim = hidden_dim
        self.num_wins = num_wins
        self.num_nodes = num_nodes

        # Submodules are created in the original order, so RNG-driven
        # initialisation and state_dict keys are unchanged.
        self.STPD_all = nn.ModuleList([STPD(self.num_features, self.hidden_dim) for _ in range(self.num_wins - 1)])
        self.STPFGC_all = nn.ModuleList([TPGCN(self.hidden_dim, self.hidden_dim) for _ in range(self.num_wins)])
        self.f1 = nn.Flatten()
        # After averaging over windows each sample is (num_nodes, 2 * hidden_dim):
        # the temporal and spatial branches are concatenated on the last dim.
        self.l1 = nn.Linear(self.hidden_dim * self.num_nodes * 2, 32)
        self.bn1 = nn.BatchNorm1d(32)
        self.l2 = nn.Linear(32, 2)

    def forward(self, fea_all, net_all):
        """Run the model.

        Args:
            fea_all: Node features, indexed per window as ``fea_all[:, t]``.
                Presumably (batch, num_wins, num_nodes, num_features); TODO
                confirm against the data loader.
            net_all: Adjacency matrices, indexed per window as
                ``net_all[:, t]``. Presumably
                (batch, num_wins, num_nodes, num_nodes).

        Returns:
            Tuple ``(log_probs, loss_CC, loss_REC)``:
            - ``log_probs``: ``F.log_softmax`` over dim 1 of the 2-class head.
            - ``loss_CC``: the accumulated ``CS_loss`` terms.
            - ``loss_REC``: the accumulated ``Reconstruction_loss`` terms.
        """
        spa_nets, tem_nets, loss_CC, loss_REC = self._decompose_windows(fea_all, net_all)
        spa_nets, tem_nets = self._reweight(spa_nets, tem_nets)
        st_fea = self._fuse_windows(spa_nets, tem_nets, net_all)
        st_fea = torch.mean(st_fea, dim=1)
        st_fea = self.f1(st_fea)
        out = self.bn1(self.l1(st_fea))
        out = F.log_softmax(self.l2(out), dim=1)
        return out, loss_CC, loss_REC

    def _decompose_windows(self, fea_all, net_all):
        """Run each STPD on its window pair; return stacked S/T embeddings and aux losses."""
        spa_nets = []
        tem_nets = []
        loss_CC = 0
        loss_REC = 0
        last = len(self.STPD_all) - 1
        for t, STPD_layer in enumerate(self.STPD_all):
            fea1 = fea_all[:, t, :, :]
            fea2 = fea_all[:, t + 1, :, :]
            net1 = net_all[:, t, :, :]
            net2 = net_all[:, t + 1, :, :]
            S1_out, T1_out, rec_ST1, S2_out, T2_out, rec_ST2 = STPD_layer(net1, net2, fea1, fea2)
            # |CS(S, T)| within each window is added, while |CS(S1, S2)|
            # across windows is subtracted (note the minus sign). CS_loss is
            # defined in `layer`; its exact form is not visible here.
            loss_CC1 = torch.abs(CS_loss(S1_out, T1_out))
            loss_CC2 = torch.abs(CS_loss(S2_out, T2_out))
            loss_CC3 = - torch.abs(CS_loss(S1_out, S2_out))
            loss_CC = loss_CC + loss_CC1 + loss_CC2 + loss_CC3
            # Compare the input adjacency with the S+T sum, both averaged over dims 0 and 2.
            loss_rec1 = Reconstruction_loss(torch.mean(net1, dim=(0, 2)), torch.mean(rec_ST1, dim=(0, 2)))
            loss_rec2 = Reconstruction_loss(torch.mean(net2, dim=(0, 2)), torch.mean(rec_ST2, dim=(0, 2)))
            loss_REC = loss_REC + loss_rec1 + loss_rec2
            # Pair t contributes window t; the final pair also contributes the
            # last window, so num_wins embeddings are collected in total.
            spa_nets.append(S1_out)
            tem_nets.append(T1_out)
            if t == last:
                spa_nets.append(S2_out)
                tem_nets.append(T2_out)
        return torch.stack(spa_nets, dim=1), torch.stack(tem_nets, dim=1), loss_CC, loss_REC

    def _reweight(self, spa_nets, tem_nets):
        """Scale the stacked spatial/temporal embeddings by S_weights / T_weights."""
        SH_weight = S_weights(spa_nets, self.hidden_dim).unsqueeze(1)
        spa_nets = spa_nets * SH_weight
        TH_weight = T_weights(tem_nets, self.hidden_dim).unsqueeze(1)
        tem_nets = tem_nets * TH_weight
        return spa_nets, tem_nets

    def _fuse_windows(self, spa_nets, tem_nets, net_all):
        """Apply each window's TPGCN to both branches; stack the results on dim 1."""
        st_fea = []
        for t, STPFGC_layer in enumerate(self.STPFGC_all):
            adj = net_all[:, t, :, :]
            tem_in = tem_nets[:, t, :, :]
            spa_in = spa_nets[:, t, :, :]
            if t > 0:
                # Every window after the first also receives the previous window's embedding.
                tem_in = tem_in + tem_nets[:, t - 1, :, :]
                spa_in = spa_in + spa_nets[:, t - 1, :, :]
            # The same TPGCN (shared weights) processes both branches.
            st_fea1 = STPFGC_layer(tem_in, adj)
            st_fea2 = STPFGC_layer(spa_in, adj)
            st_fea.append(torch.cat((st_fea1, st_fea2), dim=2))
        return torch.stack(st_fea, dim=1)



def _benchmark():
    """Smoke-test model_all, report thop FLOPs/params, and time inference.

    Uses CUDA when available and falls back to CPU otherwise. Runs only when
    this file is executed as a script, not when it is imported.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _sync():
        # CUDA kernels run asynchronously; wait for them before reading the clock.
        if device.type == "cuda":
            torch.cuda.synchronize()

    batch_size = 8
    fea_all = torch.randn(batch_size, 6, 90, 80).to(device)
    net_all = torch.randn(batch_size, 6, 90, 90).to(device)
    model = model_all(80, 16, 6).to(device)
    out, loss_CC, loss_REC = model(fea_all, net_all)
    print(out)
    model.eval()

    flops, params = profile(model, (fea_all, net_all,))
    print('flops: ', flops, 'params: ', params)
    # Dividing by batch_size * 1e6 (== 8000000.0) reports per-sample millions.
    # NOTE(review): thop.profile counts multiply-accumulates; confirm the "flops" label.
    print('flops: %.6f M, params: %.6f M' % (flops / (batch_size * 1e6), params / 1000000.0))

    # Warm up the device so timing reflects steady-state performance.
    with torch.no_grad():
        for _ in range(5):
            model(fea_all, net_all)
    _sync()

    # Run inference many times and accumulate the per-iteration wall time.
    n_iterations = 100
    total_time = 0.0
    with torch.no_grad():
        for _ in range(n_iterations):
            start_time = time.perf_counter()
            model(fea_all, net_all)
            _sync()
            total_time += time.perf_counter() - start_time

    # Average inference time per forward pass.
    average_time = total_time / n_iterations
    print(f'Average inference time: {average_time:.6f} seconds')


if __name__ == "__main__":
    _benchmark()
